{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f4ff218d-f4eb-4e4b-9f4c-205351349fd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "import os\n",
    "import random\n",
    "import datasets\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "from tqdm.notebook import tqdm\n",
    "from datasets import load_dataset, load_from_disk, Dataset, DatasetDict\n",
    "from transformers import AutoTokenizer, AutoConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc6368fa-4fcd-4336-a548-f37bda3a5d2f",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Sentiment Datasets (following FinGPT v3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90c1b2b0-550c-47a7-8486-c7f5601bc0a1",
   "metadata": {},
   "source": [
    "### 1. FPB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0d79a0e4-f8f8-4e20-811e-42256b0362df",
   "metadata": {},
   "outputs": [],
   "source": [
    "dic = {\n",
    "    0:\"negative\",\n",
    "    1:'neutral',\n",
    "    2:'positive',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "85a3d5c8-6a06-40e5-8ae7-06d1fd93f479",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 3634\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# FPB: Financial PhraseBank, sentences_50agree config.\n",
    "# Hub equivalent: load_dataset(\"financial_phrasebank\", \"sentences_50agree\")\n",
    "fpb_raw = load_from_disk('../data/financial_phrasebank-sentences_50agree/')['train']\n",
    "fpb_df = fpb_raw.to_pandas()\n",
    "# Positional rename of the two source columns.\n",
    "fpb_df.columns = ['input', 'output']\n",
    "fpb_df['output'] = fpb_df['output'].apply(lambda label_id: dic[label_id])\n",
    "fpb_df['instruction'] = 'What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.'\n",
    "# Keep only the train side of a seeded split (default test_size); the held-out rows stay out of training.\n",
    "fpb_datasets = datasets.Dataset.from_pandas(fpb_df).train_test_split(seed=42)['train']\n",
    "fpb_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8d8d28e4-f380-423b-8a1b-51bfcb0b7004",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 21804\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeat FPB 6x so that each data source contributes a similar number of samples.\n",
    "train_dataset = datasets.concatenate_datasets([fpb_datasets for _ in range(6)])\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9097f1f3-598c-47d3-8b96-a08254da4c4e",
   "metadata": {},
   "source": [
    "### 2. FiQA SA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "55d618cc-9e05-4053-b354-e7b159c6906b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_label(x):\n",
    "    if x < - 0.1:\n",
    "        return \"negative\"\n",
    "    elif -0.1 <= x < 0.1:\n",
    "        return \"neutral\"\n",
    "    else:\n",
    "        return \"positive\"\n",
    "\n",
    "def add_instructions(x):\n",
    "    if x == \"post\":\n",
    "        return \"What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.\"\n",
    "    else:\n",
    "        return \"What is the sentiment of this news? Please choose an answer from {negative/neutral/positive}.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b81fb40-1275-4205-9d42-dbfc425dd4a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 938\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# FiQA-2018 sentiment analysis; hub equivalent: load_dataset('pauri32/fiqa-2018').\n",
    "# All three published splits are pooled before re-splitting below.\n",
    "fiqa_splits = load_from_disk('../data/fiqa-2018/')\n",
    "fiqa_df = datasets.concatenate_datasets(\n",
    "    [fiqa_splits['train'], fiqa_splits['validation'], fiqa_splits['test']]\n",
    ").to_pandas()\n",
    "fiqa_df['output'] = fiqa_df['sentiment_score'].apply(make_label)\n",
    "fiqa_df['instruction'] = fiqa_df['format'].apply(add_instructions)\n",
    "fiqa_df = fiqa_df[['sentence', 'output', 'instruction']].rename(columns={'sentence': 'input'})\n",
    "# Hold out 22.6% of the rows (seeded) and keep only the train side.\n",
    "dataset = datasets.Dataset.from_pandas(fiqa_df).train_test_split(test_size=0.226, seed=42)['train']\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "80349c7a-3c0c-4c99-b2c3-f1b2dd377e64",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19698\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 41502\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upsample FiQA 21x so it contributes a similar number of samples to the other sources.\n",
    "fiqa_upsampled = datasets.concatenate_datasets([dataset for _ in range(21)])\n",
    "train_dataset = datasets.concatenate_datasets([train_dataset, fiqa_upsampled])\n",
    "print(fiqa_upsampled.num_rows)\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4887301-4f6a-49b0-9e42-aa5bdcd19856",
   "metadata": {},
   "source": [
    "### 3. TFNS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e707ffbf-7f92-4b45-8570-a6d63bc64bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "dic = {\n",
    "    0:\"negative\",\n",
    "    1:'positive',\n",
    "    2:'neutral',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "41388d35-c6ae-4da8-b92d-13fb0b1db21e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 9543\n",
       "})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TFNS; hub equivalent: load_dataset('zeroshot/twitter-financial-news-sentiment'). Only the train split is used.\n",
    "social_media_df = load_from_disk('../data/twitter-financial-news-sentiment')['train'].to_pandas()\n",
    "social_media_df['label'] = social_media_df['label'].apply(lambda label_id: dic[label_id])\n",
    "social_media_df['instruction'] = 'What is the sentiment of this tweet? Please choose an answer from {negative/neutral/positive}.'\n",
    "# Positional rename of all three columns.\n",
    "social_media_df.columns = ['input', 'output', 'instruction']\n",
    "social_media_dataset = datasets.Dataset.from_pandas(social_media_df)\n",
    "social_media_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9f1ef9b7-af0c-407b-b270-7c4b2180036c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19086\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 60588\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeat TFNS twice before appending it to the running training set.\n",
    "tfns_upsampled = datasets.concatenate_datasets([social_media_dataset, social_media_dataset])\n",
    "train_dataset = datasets.concatenate_datasets([train_dataset, tfns_upsampled])\n",
    "print(tfns_upsampled.num_rows)\n",
    "train_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aee85f1d-7c0e-4642-adbd-f145adb23ef1",
   "metadata": {},
   "source": [
    "### 4. NWGI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "642fcaf0-b909-4562-a913-877557d829bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 16184\n",
       "})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NWGI; hub equivalent: load_dataset('oliverwang15/news_with_gpt_instructions'). Uses a 7-level label set.\n",
    "nwgi_df = load_from_disk('../data/news_with_gpt_instructions/')['train'].to_pandas()\n",
    "nwgi_df = nwgi_df.assign(\n",
    "    output=nwgi_df['label'],\n",
    "    input=nwgi_df['news'],\n",
    "    instruction='What is the sentiment of this news? Please choose an answer from {strong negative/moderately negative/mildly negative/neutral/mildly positive/moderately positive/strong positive}.',\n",
    ")[['input', 'output', 'instruction']]\n",
    "finance_dataset = datasets.Dataset.from_pandas(nwgi_df)\n",
    "finance_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "98e24280-119d-4748-8c1d-709d7d3cf50d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(76772, 3)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Append NWGI as-is (no upsampling), then shuffle the combined set reproducibly.\n",
    "train_dataset = datasets.concatenate_datasets([train_dataset, finance_dataset])\n",
    "all_dataset = train_dataset.shuffle(seed=42)\n",
    "(all_dataset.num_rows, all_dataset.num_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a8f4d2a8-c0ac-49ce-8933-cce0de14f99f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: log in to the Hugging Face Hub; only needed if the push_to_hub call in the next cell is enabled.\n",
    "# from huggingface_hub import notebook_login\n",
    "# notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "4509a524-d694-4bcf-813c-4a28551bed39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: publish the shuffled sentiment training set to the Hugging Face Hub (requires the login in the previous cell).\n",
    "# all_dataset.push_to_hub(\"fingpt_chatglm2_sentiment_instruction_lora_ft_dataset\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b7774a58-e4e6-49ac-81fb-ef72dc46446e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "76772"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the FPB + FiQA + TFNS + NWGI row counts reported above should add up to all_dataset.shape[0].\n",
    "sum([21804, 19698, 19086, 16184])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "1c740a64-d41d-43de-b20b-7826f01e9421",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004731893539428711,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 76772,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24c8597611564c9a8f2aa8055b126bf7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/76772 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sentiment_train_dir = 'fingpt-sentiment-train'  # relative to the notebook's working directory\n",
    "all_dataset.save_to_disk(sentiment_train_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26ff88aa-2c9e-4627-8e5f-4517809bc192",
   "metadata": {},
   "source": [
    "## CLS Version for Zero-shot Assessment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "363744e9-cecb-45b8-b2ea-df22ff786e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "chatgpt_templates = [\n",
    "    \"What is the sentiment of the input {type} from financial perspective?\",\n",
    "    \"Assign a sentiment category to this {type} related to finance.\",\n",
    "    \"Categorize the input {type}'s emotional tone into one of three groups.\",\n",
    "    \"Determine the sentiment expressed in the {type} from financial perspective.\",\n",
    "    \"Characterize the {type}'s sentiment using the following options.\",\n",
    "]\n",
    "\n",
    "with open('../benchmarks/sentiment_templates.txt', 'w') as f:\n",
    "    f.writelines([l + '\\n' for l in chatgpt_templates])\n",
    "\n",
    "\n",
    "def option_list(ops_str):\n",
    "    options = ops_str.split('/')\n",
    "    random.shuffle(options)\n",
    "    return \", \".join(options)\n",
    "\n",
    "\n",
    "def map_func(feature):\n",
    "    t = re.search(r'tweet|news', feature['instruction']).group()\n",
    "    ops = option_list(\"negative/positive\")\n",
    "    if 'positive' in feature['output']:\n",
    "        output = 'positive'\n",
    "    elif 'negative' in feature['output']:\n",
    "        output = 'negative' \n",
    "    else:\n",
    "        output = 'neutral'\n",
    "    return {\"instruction\": random.choice(chatgpt_templates).format(type=t) + \"\\nOptions: \" + ops, \"output\": output}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aca0c743-68b1-4d4d-b75b-5e51d36c2318",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004678487777709961,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Map",
       "rate": null,
       "total": 76772,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2798b09bd1ff4167bd432f92d26503f0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/76772 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004431486129760742,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Filter",
       "rate": null,
       "total": 76772,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c41f028e5dd24a1ca7eaa920f3a480ab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Filter:   0%|          | 0/76772 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 47557\n",
       "})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random.seed(0)  # fixes the template/option draws made inside map_func\n",
    "all_dataset_instruct = (\n",
    "    all_dataset\n",
    "    .map(map_func)\n",
    "    .filter(lambda example: example['output'] != 'neutral')  # binary zero-shot task: drop neutral rows\n",
    ")\n",
    "all_dataset_instruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "39a9f66a-e9e4-4392-8e7e-1f3348dc0be3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004517316818237305,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 47557,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "651b706e66bd43b2a2184b663ec0aa21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/47557 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "['Determine the sentiment expressed in the news from financial perspective.\\nOptions: negative, positive',\n",
       " 'Determine the sentiment expressed in the tweet from financial perspective.\\nOptions: negative, positive',\n",
       " \"Characterize the news's sentiment using the following options.\\nOptions: negative, positive\",\n",
       " \"Characterize the news's sentiment using the following options.\\nOptions: negative, positive\",\n",
       " 'What is the sentiment of the input tweet from financial perspective?\\nOptions: positive, negative',\n",
       " 'Determine the sentiment expressed in the tweet from financial perspective.\\nOptions: negative, positive',\n",
       " \"Categorize the input tweet's emotional tone into one of three groups.\\nOptions: negative, positive\",\n",
       " \"Characterize the news's sentiment using the following options.\\nOptions: positive, negative\",\n",
       " 'Determine the sentiment expressed in the news from financial perspective.\\nOptions: negative, positive',\n",
       " 'What is the sentiment of the input tweet from financial perspective?\\nOptions: negative, positive']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cls_instruct_dir = 'fingpt-sentiment-cls-instruct'\n",
    "all_dataset_instruct.save_to_disk(cls_instruct_dir)\n",
    "# Preview the first ten rewritten instructions.\n",
    "all_dataset_instruct[:10]['instruction']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a9357d-7b48-46f8-9ebe-cfcb36dec1d6",
   "metadata": {},
   "source": [
    "# Headline Datasets (Gold News Headline Classification)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2b71a977-3ed7-459d-97bf-878d086da1f8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Dates</th>\n",
       "      <th>URL</th>\n",
       "      <th>News</th>\n",
       "      <th>Price or Not</th>\n",
       "      <th>Price Direction Up</th>\n",
       "      <th>Price Direction Constant</th>\n",
       "      <th>Price Direction Down</th>\n",
       "      <th>PastPrice</th>\n",
       "      <th>FuturePrice</th>\n",
       "      <th>PastNews</th>\n",
       "      <th>FutureNews</th>\n",
       "      <th>Asset Comparision</th>\n",
       "      <th>Price Sentiment</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>28-01-2016</td>\n",
       "      <td>http://www.marketwatch.com/story/april-gold-do...</td>\n",
       "      <td>april gold down 20 cents to settle at $1,116.1...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>13-09-2017</td>\n",
       "      <td>http://www.marketwatch.com/story/gold-prices-s...</td>\n",
       "      <td>gold suffers third straight daily decline</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>negative</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>26-07-2016</td>\n",
       "      <td>http://www.marketwatch.com/story/gold-futures-...</td>\n",
       "      <td>Gold futures edge up after two-session decline</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>28-02-2018</td>\n",
       "      <td>https://www.metalsdaily.com/link/277199/dent-r...</td>\n",
       "      <td>dent research : is gold's day in the sun comin...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>none</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>06-09-2017</td>\n",
       "      <td>http://www.marketwatch.com/story/gold-steadies...</td>\n",
       "      <td>Gold snaps three-day rally as Trump, lawmakers...</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>positive</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Dates                                                URL  \\\n",
       "0  28-01-2016  http://www.marketwatch.com/story/april-gold-do...   \n",
       "1  13-09-2017  http://www.marketwatch.com/story/gold-prices-s...   \n",
       "2  26-07-2016  http://www.marketwatch.com/story/gold-futures-...   \n",
       "3  28-02-2018  https://www.metalsdaily.com/link/277199/dent-r...   \n",
       "4  06-09-2017  http://www.marketwatch.com/story/gold-steadies...   \n",
       "\n",
       "                                                News  Price or Not  \\\n",
       "0  april gold down 20 cents to settle at $1,116.1...             1   \n",
       "1          gold suffers third straight daily decline             1   \n",
       "2     Gold futures edge up after two-session decline             1   \n",
       "3  dent research : is gold's day in the sun comin...             0   \n",
       "4  Gold snaps three-day rally as Trump, lawmakers...             1   \n",
       "\n",
       "   Price Direction Up  Price Direction Constant  Price Direction Down  \\\n",
       "0                   0                         0                     1   \n",
       "1                   0                         0                     1   \n",
       "2                   1                         0                     0   \n",
       "3                   0                         0                     0   \n",
       "4                   1                         0                     0   \n",
       "\n",
       "   PastPrice  FuturePrice  PastNews  FutureNews  Asset Comparision  \\\n",
       "0          1            0         0           0                  0   \n",
       "1          1            0         0           0                  0   \n",
       "2          1            0         0           0                  0   \n",
       "3          0            0         0           1                  0   \n",
       "4          1            0         0           0                  0   \n",
       "\n",
       "  Price Sentiment  \n",
       "0        negative  \n",
       "1        negative  \n",
       "2        positive  \n",
       "3            none  \n",
       "4        positive  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gold-news headline dataset (Sinha & Khandait); the CSV is expected in the working directory.\n",
    "df = pd.read_csv('gold-dataset-sinha-khandait.csv')\n",
    "df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2e848299-c202-4163-b46a-336af9b98495",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9129\n",
      "2283\n"
     ]
    }
   ],
   "source": [
    "# One yes/no question per annotated column, asked in this order for every headline.\n",
    "headline_questions = [\n",
    "    ('Does the news headline talk about price? Please choose an answer from {Yes/No}.', 'Price or Not'),\n",
    "    ('Does the news headline talk about price going up? Please choose an answer from {Yes/No}.', 'Price Direction Up'),\n",
    "    ('Does the news headline talk about price staying constant? Please choose an answer from {Yes/No}.', 'Price Direction Constant'),\n",
    "    ('Does the news headline talk about price going down? Please choose an answer from {Yes/No}.', 'Price Direction Down'),\n",
    "    ('Does the news headline talk about price in the past? Please choose an answer from {Yes/No}.', 'PastPrice'),\n",
    "    ('Does the news headline talk about price in the future? Please choose an answer from {Yes/No}.', 'FuturePrice'),\n",
    "    ('Does the news headline talk about a general event (apart from prices) in the past? Please choose an answer from {Yes/No}.', 'PastNews'),\n",
    "    ('Does the news headline talk about a general event (apart from prices) in the future? Please choose an answer from {Yes/No}.', 'FutureNews'),\n",
    "    ('Does the news headline compare gold with any other asset? Please choose an answer from {Yes/No}.', 'Asset Comparision'),  # column name is spelled this way in the CSV\n",
    "]\n",
    "\n",
    "# Sequential 80/20 split by row position: rows before the 80% mark go to train, the rest to test.\n",
    "train_dataset, test_dataset = {}, {}\n",
    "inputs, outputs, instructions = [], [], []\n",
    "for position, (_, row) in enumerate(df.iterrows()):\n",
    "    if position + 1 >= len(df) * 0.8 and not train_dataset:\n",
    "        # Freeze what has been collected so far as the train split and start collecting test rows.\n",
    "        train_dataset['input'] = inputs\n",
    "        train_dataset['output'] = outputs\n",
    "        train_dataset['instruction'] = instructions\n",
    "        inputs, outputs, instructions = [], [], []\n",
    "\n",
    "    for instruction, column in headline_questions:\n",
    "        inputs.append(row['News'])\n",
    "        instructions.append(instruction)\n",
    "        outputs.append('Yes' if row[column] else 'No')\n",
    "\n",
    "test_dataset['input'] = inputs\n",
    "test_dataset['output'] = outputs\n",
    "test_dataset['instruction'] = instructions\n",
    "\n",
    "# Headlines per split (each headline yields one example per question).\n",
    "print(len(train_dataset['input']) // len(headline_questions))\n",
    "print(len(test_dataset['input']) // len(headline_questions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "6ae82176-fd34-43b9-a3ad-7c85bf35ee08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004788875579833984,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 82161,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "670ca6e61e0141cfbe00f348caa2bdf5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/82161 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004403829574584961,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 20547,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33fd0bb8a0d4403eb67aa4c002a5ffc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/20547 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 82161\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 20547\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the train/test dicts built above in a DatasetDict and save it.\n",
    "headline_splits = {\n",
    "    'train': Dataset.from_dict(train_dataset),\n",
    "    'test': Dataset.from_dict(test_dataset),\n",
    "}\n",
    "headline_dataset = DatasetDict(headline_splits)\n",
    "headline_dataset.save_to_disk('fingpt-headline')\n",
    "headline_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64d9cfc3-02d6-4e46-8dfb-93d43b1bc8d5",
   "metadata": {},
   "source": [
    "## CLS Ver. for Zero-shot Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b72efdd8-e920-416d-9a25-8aa0c65769e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paraphrased Yes/No question templates for the zero-shot CLS variant;\n",
    "# `{subject}` is filled with the topic extracted from the original headline instruction.\n",
    "chatgpt_templates = [\n",
    "    \"Does the news headline talk about {subject}?\",\n",
    "    \"Please determine if the news headline addresses {subject}.\",\n",
    "    \"In the context of the news headline, is {subject} discussed?\",\n",
    "    \"Is the news headline related to {subject}?\",\n",
    "    \"Let me know if the news headline talks about {subject}.\",\n",
    "    \"Consider the news headline - does it concern {subject}?\",\n",
    "    \"Examine the news headline and decide if it includes {subject}.\",\n",
    "    \"Assess if the news headline touches on {subject}.\",\n",
    "    \"Review the news headline and determine if it relates to {subject}.\",\n",
    "    \"Analyze the news headline for any mention of {subject}.\",\n",
    "    \"Interpret the news headline to see if it mentions {subject}.\"\n",
    "]\n",
    "\n",
    "def yes_no():\n",
    "    \"\"\"Return 'Yes, No' or 'No, Yes', shuffled with the global `random` module.\"\"\"\n",
    "    options = ['Yes', 'No']\n",
    "    random.shuffle(options)\n",
    "    return \", \".join(options)\n",
    "\n",
    "def map_func(feature):\n",
    "    \"\"\"Replace an example's instruction with a randomly templated Yes/No question.\n",
    "\n",
    "    The subject is taken from the original instruction's 'talk about <subject>? Please' span;\n",
    "    when that pattern is absent, a fixed gold-comparison subject is used instead.\n",
    "    Randomness comes from the global `random` module (template first, then option order),\n",
    "    so seed it beforehand for a reproducible dataset.\n",
    "\n",
    "    Returns:\n",
    "        dict: only the key 'instruction': the question, a newline, then 'Options: ' and the\n",
    "        two options in random order.\n",
    "    \"\"\"\n",
    "    match_res = re.search(r'talk about (.*)\\? Please', feature['instruction'])\n",
    "    # NOTE(review): assumes the only non-matching source instruction is the gold-comparison\n",
    "    # question -- confirm against the headline instructions built earlier.\n",
    "    subject = 'comparing gold with any other asset' if match_res is None else match_res.group(1)\n",
    "    template = random.choice(chatgpt_templates)\n",
    "    return {\"instruction\": template.format(subject=subject) + \"\\nOptions: \" + yes_no()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "50c93371-343c-458f-8500-89e418c02969",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004721164703369141,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Map",
       "rate": null,
       "total": 82161,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a421304eaaad44dfac5e9e3da95a9791",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/82161 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0043904781341552734,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Map",
       "rate": null,
       "total": 20547,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "adab61cca62342e487bf14b46f017214",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/20547 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004418373107910156,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 82161,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ff2661ab7294006861eb7f827297b7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/82161 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004082918167114258,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 20547,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "11a9114fba2f4d3e85c1a4034e8e755c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/20547 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Seed the global RNG: map_func draws its template and Yes/No order from `random`.\n",
    "random.seed(0)\n",
    "instruct_output_dir = 'fingpt-headline-cls-instruct'\n",
    "headline_dataset_instruct = headline_dataset.map(map_func)\n",
    "headline_dataset_instruct.save_to_disk(instruct_output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "00916647-6764-464e-b0f3-2299675da1f7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Examine the news headline and decide if it includes price.\\nOptions: Yes, No',\n",
       " 'Does the news headline talk about price going up? \\nOptions: Yes, No',\n",
       " 'Review the news headline and determine if it relates to price staying constant.\\nOptions: Yes, No',\n",
       " 'Examine the news headline and decide if it includes price going down.\\nOptions: Yes, No',\n",
       " 'Assess if the news headline touches on price in the past.\\nOptions: Yes, No',\n",
       " 'Analyze the news headline for any mention of price in the future.\\nOptions: No, Yes',\n",
       " 'Review the news headline and determine if it relates to a general event (apart from prices) in the past.\\nOptions: No, Yes',\n",
       " 'Let me know if the news headline talks about a general event (apart from prices) in the future.\\nOptions: No, Yes',\n",
       " 'Please determine if the news headline addresses comparing gold with any other asset.\\nOptions: Yes, No',\n",
       " 'Review the news headline and determine if it relates to price.\\nOptions: No, Yes']"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Slice rows before selecting the column so only 10 instructions are materialized.\n",
    "headline_dataset_instruct['train'][:10]['instruction']"
   ]
  },
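  {
   "cell_type": "markdown",
   "id": "add5ca03-de6f-47be-9ebe-9c63e194b120",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Named Entity Recognition (NER)\n",
    "\n",
    "Built from the CoNLL-format SEC-filings NER data under `./SEC-filings/CONLL-format/data`."
   ]
  },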
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "95459177-9eff-40e3-946d-035043e22650",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read and parse the CoNLL-2003 formatted dataset\n",
    "\n",
    "# Maps the CoNLL entity-type suffix (label text after the last '-') to the name used in outputs.\n",
    "# MISC is intentionally absent: get_ner_dataset ignores MISC entities.\n",
    "ent_dict = {\n",
    "    'PER': 'person',\n",
    "    'ORG': 'organization',\n",
    "    'LOC': 'location',\n",
    "}\n",
    "\n",
    "def read_conll_file(file_path):\n",
    "    \"\"\"Parse a CoNLL-formatted file into sentences of (token, label) pairs.\n",
    "\n",
    "    The token is the first whitespace-separated column of a line and the label the last one.\n",
    "    Blank lines and '-DOCSTART-' markers end the current sentence; the final sentence is kept\n",
    "    even when the file does not end with a blank line.\n",
    "\n",
    "    Returns:\n",
    "        list[list[tuple[str, str]]]: one list of (token, label) pairs per sentence.\n",
    "    \"\"\"\n",
    "    sentences, sentence = [], []\n",
    "    with open(file_path, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            parts = line.split()\n",
    "            if parts and parts[0] != '-DOCSTART-':\n",
    "                sentence.append((parts[0], parts[-1]))\n",
    "            elif sentence:\n",
    "                # Sentence boundary: blank line or document marker.\n",
    "                sentences.append(sentence)\n",
    "                sentence = []\n",
    "    # Flush the last sentence when the file has no trailing blank line.\n",
    "    if sentence:\n",
    "        sentences.append(sentence)\n",
    "    return sentences\n",
    "\n",
    "\n",
    "def _extract_entities(sentence, is_entity):\n",
    "    \"\"\"Group consecutive kept tokens of one sentence into entity spans.\n",
    "\n",
    "    A span ends at a non-entity token, a change of entity type, a 'B-' label (explicit start\n",
    "    of a new entity) or the end of the sentence, so entities on the last token are kept.\n",
    "\n",
    "    Returns:\n",
    "        list[tuple[str, str]]: (space-joined entity text, type suffix such as 'PER'), in order.\n",
    "    \"\"\"\n",
    "    spans, current, current_type = [], [], None\n",
    "    for (token, label), keep in zip(sentence, is_entity):\n",
    "        label_type = label.split('-')[-1]\n",
    "        if current and (not keep or label_type != current_type or label.startswith('B-')):\n",
    "            spans.append((' '.join(current), current_type))\n",
    "            current = []\n",
    "        if keep:\n",
    "            current.append(token)\n",
    "            current_type = label_type\n",
    "    if current:\n",
    "        spans.append((' '.join(current), current_type))\n",
    "    return spans\n",
    "\n",
    "\n",
    "def get_ner_dataset(sentences):\n",
    "    \"\"\"Build NER instruction-tuning examples from parsed CoNLL sentences.\n",
    "\n",
    "    Sentences without any PER/ORG/LOC token are skipped (MISC labels are ignored). Each kept\n",
    "    sentence yields one example whose output reads '<entity> is a/an <type>, ...' and ends\n",
    "    with a period. Prints the number of examples and the per-type entity counts.\n",
    "\n",
    "    Returns:\n",
    "        dict: parallel lists under 'input', 'output' and 'instruction'.\n",
    "    \"\"\"\n",
    "    inputs, outputs, instructions = [], [], []\n",
    "    count = {'person': 0, 'organization': 0, 'location': 0}\n",
    "    for sentence in sentences:\n",
    "        is_entity = [label != 'O' and not label.endswith('MISC') for _, label in sentence]\n",
    "        if not any(is_entity):\n",
    "            continue\n",
    "        descriptions = []\n",
    "        for entity, type_suffix in _extract_entities(sentence, is_entity):\n",
    "            entity_type = ent_dict[type_suffix]\n",
    "            article = 'an' if entity_type == 'organization' else 'a'\n",
    "            descriptions.append(f'{entity} is {article} {entity_type}')\n",
    "            count[entity_type] += 1\n",
    "        instructions.append('Please extract entities and their types from the input sentence, entity types should be chosen from {person/organization/location}.')\n",
    "        inputs.append(' '.join(token for token, _ in sentence))\n",
    "        outputs.append(', '.join(descriptions) + '.')\n",
    "\n",
    "    print(len(instructions))\n",
    "    print(count)\n",
    "\n",
    "    return {\"input\": inputs, \"output\": outputs, \"instruction\": instructions}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9aeec446-04ce-44d5-abea-38b65aece991",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "511\n",
      "{'person': 745, 'organization': 243, 'location': 168}\n",
      "98\n",
      "{'person': 216, 'organization': 56, 'location': 39}\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.00464177131652832,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 511,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6de323eafdf649b8af9b322dc33d0fcb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/511 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004297494888305664,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 98,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "834f28cf02724a729c9421e7b29349c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/98 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 511\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 98\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "NER_DATA_DIR = './SEC-filings/CONLL-format/data'\n",
    "\n",
    "# Raw sentence lists get their own names; train_data / test_data hold the instruction dicts.\n",
    "train_sentences = read_conll_file(f'{NER_DATA_DIR}/train/FIN5.txt')\n",
    "test_sentences = read_conll_file(f'{NER_DATA_DIR}/test/FIN3.txt')\n",
    "\n",
    "train_data = get_ner_dataset(train_sentences)\n",
    "test_data = get_ner_dataset(test_sentences)\n",
    "\n",
    "ner_dataset = DatasetDict({\n",
    "    'train': Dataset.from_dict(train_data),\n",
    "    'test': Dataset.from_dict(test_data)\n",
    "})\n",
    "ner_dataset.save_to_disk('fingpt-ner')\n",
    "ner_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "32c6d6c2-3d15-4382-bc0a-5c1d5b17fc00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_ner_dataset_1(sentences):\n",
    "    \n",
    "#     inputs, outputs, instructions = [], [], []\n",
    "#     count = {'person': 0, 'organization': 0, 'location': 0}\n",
    "#     for sentence in sentences:\n",
    "#         is_entity = [tup[1] != 'O' and not tup[1].endswith('MISC') for tup in sentence]\n",
    "#         if sum(is_entity) == 0:\n",
    "#             continue\n",
    "#         instructions.append('Please list all entities in the input text that fit the following entity types: \"person\", \"organization\", \"location\". Output format is \"type1: entity1; type2: entity2\"')\n",
    "#         instructions.append('Given options of entity types, please find all the entities associated with them in the input text. answer with format \"entity1: type1; entity2: type2\".\\nOptions: \"person\", \"organization\", \"location\".')\n",
    "#         inputs.append(' '.join([tup[0] for tup in sentence]))\n",
    "#         inputs.append(' '.join([tup[0] for tup in sentence]))\n",
    "#         outputs.append('')\n",
    "#         outputs.append('')\n",
    "#         tmp_tup_list = []\n",
    "#         for i, tup in enumerate(sentence):\n",
    "#             if tmp_tup_list and (not is_entity[i] or tmp_tup_list[-1][1] != tup[1] or i + 1 == len(sentence)):\n",
    "#                 entity = ' '.join([t[0] for t in tmp_tup_list])\n",
    "#                 assert tmp_tup_list[0][1] == tmp_tup_list[-1][1], tmp_tup_list\n",
    "#                 entity_type = ent_dict[tmp_tup_list[-1][1].split('-')[-1]]\n",
    "#                 outputs[-2] += f'{entity_type}: {entity}; ' \n",
    "#                 outputs[-1] += f'{entity}: {entity_type}; ' \n",
    "#                 tmp_tup_list = [] if not is_entity[i] else [tup]\n",
    "#                 count[entity_type] += 1\n",
    "#             elif is_entity[i]:\n",
    "#                 tmp_tup_list.append(tup)\n",
    "#             else:\n",
    "#                 pass\n",
    "#         outputs[-1] = outputs[-1].strip('; ')\n",
    "#         outputs[-2] = outputs[-2].strip('; ')\n",
    "    \n",
    "#     print(len(instructions))\n",
    "#     print(count)\n",
    "        \n",
    "#     return {\"input\": inputs, \"output\": outputs, \"instruction\": instructions}\n",
    "\n",
    "\n",
    "# def get_ner_dataset_2(sentences):\n",
    "    \n",
    "#     inputs, outputs, instructions = [], [], []\n",
    "#     count = {'person': 0, 'organization': 0, 'location': 0}\n",
    "#     templates = [\n",
    "#         'Does the input text include any entity of type \"{entity_type}\" ? If so, list them all, and answer with format \"entity1; entity2\". If not, Answer \"No\".',\n",
    "#         'Please tell me all the entities in the text that belong to the given category \"{entity_type}\". Output format is \"entity1; entity2\". If no such entity can be found, answer \"No\".',\n",
    "#     ]\n",
    "#     for sentence in sentences:\n",
    "#         is_entity = [tup[1] != 'O' and not tup[1].endswith('MISC') for tup in sentence]\n",
    "#         if sum(is_entity) == 0:\n",
    "#             continue\n",
    "#         for template in templates:\n",
    "#             for tgt_entity_type in ['person', 'location', 'organization']:\n",
    "#                 instructions.append(template.format(entity_type=tgt_entity_type))\n",
    "#                 inputs.append(' '.join([tup[0] for tup in sentence]))\n",
    "#                 outputs.append('')\n",
    "#                 tmp_tup_list = []\n",
    "#                 for i, tup in enumerate(sentence):\n",
    "#                     if tmp_tup_list and (not is_entity[i] or tmp_tup_list[-1][1] != tup[1] or i + 1 == len(sentence)):\n",
    "#                         entity = ' '.join([t[0] for t in tmp_tup_list])\n",
    "#                         assert tmp_tup_list[0][1] == tmp_tup_list[-1][1], tmp_tup_list\n",
    "#                         entity_type = ent_dict[tmp_tup_list[-1][1].split('-')[-1]]\n",
    "#                         tmp_tup_list = [] if not is_entity[i] else [tup]\n",
    "#                         if entity_type == tgt_entity_type:\n",
    "#                             outputs[-1] += f'{entity}; ' \n",
    "#                             count[entity_type] += 1\n",
    "#                     elif is_entity[i]:\n",
    "#                         tmp_tup_list.append(tup)\n",
    "#                     else:\n",
    "#                         pass\n",
    "#                 outputs[-1] = 'No' if not outputs[-1] else outputs[-1].strip('; ')\n",
    "    \n",
    "#     print(len(instructions))\n",
    "#     print(count)\n",
    "        \n",
    "#     return {\"input\": inputs, \"output\": outputs, \"instruction\": instructions}\n",
    "\n",
    "\n",
    "# def get_ner_dataset_3(sentences):\n",
    "    \n",
    "#     inputs, outputs, instructions = [], [], []\n",
    "#     count = {'person': 0, 'organization': 0, 'location': 0}\n",
    "#     for sentence in sentences:\n",
    "#         is_entity = [tup[1] != 'O' and not tup[1].endswith('MISC') for tup in sentence]\n",
    "#         if sum(is_entity) == 0:\n",
    "#             continue\n",
    "#         prev_count = count.copy()\n",
    "#         instructions.append('Please find all entities in the input text that fit the following entity types: \"person\", \"organization\", \"location\", then answer with the number counted for each type. Output format is \"type1: number1; type2: number2\"')\n",
    "#         instructions.append('Given entity types as options, please find the number of occurrence for each entity type in the input text. answer with format \"type1: number1; type2: number2\".\\nOptions: \"person\", \"organization\", \"location\".')\n",
    "#         inputs.extend([' '.join([tup[0] for tup in sentence])] * 2)\n",
    "#         tmp_tup_list = []\n",
    "#         for i, tup in enumerate(sentence):\n",
    "#             if tmp_tup_list and (not is_entity[i] or tmp_tup_list[-1][1] != tup[1] or i + 1 == len(sentence)):\n",
    "#                 entity = ' '.join([t[0] for t in tmp_tup_list])\n",
    "#                 assert tmp_tup_list[0][1] == tmp_tup_list[-1][1], tmp_tup_list\n",
    "#                 entity_type = ent_dict[tmp_tup_list[-1][1].split('-')[-1]]\n",
    "#                 tmp_tup_list = [] if not is_entity[i] else [tup]\n",
    "#                 count[entity_type] += 1\n",
    "#             elif is_entity[i]:\n",
    "#                 tmp_tup_list.append(tup)\n",
    "#             else:\n",
    "#                 pass\n",
    "        \n",
    "#         per_cnt = count['person'] - prev_count['person']\n",
    "#         loc_cnt = count['location'] - prev_count['location']\n",
    "#         org_cnt = count['organization'] - prev_count['organization']\n",
    "#         output_str = f'person: {per_cnt}; location: {loc_cnt}; organization: {org_cnt}'\n",
    "#         outputs.extend([output_str] * 2)\n",
    "    \n",
    "#     print(len(instructions))\n",
    "#     print(count)\n",
    "        \n",
    "#     return {\"input\": inputs, \"output\": outputs, \"instruction\": instructions}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b00bbd11-7867-4a8e-8d9c-877d2b827160",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_data = read_conll_file('./SEC-filings/CONLL-format/data/train/FIN5.txt')\n",
    "# test_data = read_conll_file('./SEC-filings/CONLL-format/data/test/FIN3.txt')\n",
    "\n",
    "# train_data_0 = get_ner_dataset(train_data)\n",
    "# train_data_1 = get_ner_dataset_1(train_data)\n",
    "# train_data_2 = get_ner_dataset_2(train_data)\n",
    "# train_data_3 = get_ner_dataset_3(train_data)\n",
    "\n",
    "# test_data_0 = get_ner_dataset(test_data)\n",
    "# test_data_1 = get_ner_dataset_1(test_data)\n",
    "# test_data_2 = get_ner_dataset_2(test_data)\n",
    "# test_data_3 = get_ner_dataset_3(test_data)\n",
    "\n",
    "# train_data = {k: train_data_0[k] + train_data_1[k] + train_data_2[k] + train_data_3[k] for k in train_data_0.keys()} \n",
    "# test_data = {k: test_data_0[k] + test_data_1[k] + test_data_2[k] + test_data_3[k] for k in test_data_0.keys()} \n",
    "# ner_dataset = DatasetDict({\n",
    "#     'train': Dataset.from_dict(train_data),\n",
    "#     'test': Dataset.from_dict(test_data)\n",
    "# })\n",
    "# ner_dataset.save_to_disk('fingpt-ner-full')\n",
    "# ner_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b4559b61-c3eb-4a72-8a5c-aaa7bfd9d209",
   "metadata": {},
   "outputs": [],
   "source": [
    "# chatgpt_templates = {\n",
    "#     'Please list all entities in the input text that fit the following entity types: \"person\", \"organization\", \"location\". Output format is \"type1: entity1; type2: entity2\"':[\n",
    "#         \"Identify and compile all entities of types 'person,' 'organization,' and 'location' from the input text. Format the output as 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Gather all entities falling under the categories 'person,' 'organization,' and 'location' within the input text. Please use the 'type1: entity1; type2: entity2' format for the output.\",\n",
    "#         \"Extract entities categorized as 'person,' 'organization,' or 'location' from the input text. Organize the results in the format 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Compile a list of entities matching 'person,' 'organization,' and 'location' entity types from the input text. Ensure the output format remains 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Detect and list all entities that fall under the types 'person,' 'organization,' and 'location' in the input text. Format the output as 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Enumerate entities of types 'person,' 'organization,' and 'location' from the input text. Use the output format 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Identify and present entities categorized as 'person,' 'organization,' and 'location' from the input text. Keep the output in the format 'type1: entity1; type2: entity2.'\",\n",
    "#         \"Spot and document entities falling into the 'person,' 'organization,' and 'location' types within the input text. Keep the 'type1: entity1; type2: entity2' output format.\",\n",
    "#         \"Extract entities matching 'person,' 'organization,' and 'location' types from the input text. Ensure the output follows the 'type1: entity1; type2: entity2' format.\",\n",
    "#         \"Enumerate all entities that correspond to 'person,' 'organization,' and 'location' entity types within the input text. Format the output as 'type1: entity1; type2: entity2.'\"\n",
    "#     ],\n",
    "#     'Given options of entity types, please find all the entities associated with them in the input text. answer with format \"entity1: type1; entity2: type2\".\\nOptions: \"person\", \"organization\", \"location\".': [\n",
    "#         \"Using the provided options of entity types, locate all corresponding entities within the input text. Present the results in the format 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Find entities related to the given options of entity types within the input text. Respond with the format 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Identify entities associated with the provided entity type options in the input text. Format your response as 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Discover entities linked to the entity types listed in the options within the input text. Provide the output in the format 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Locate all entities corresponding to the provided entity type options in the input text. Present your findings as 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Search for entities related to the entity types listed in the options within the input text. Share the results in the 'entity1: type1; entity2: type2' format.\",\n",
    "#         \"Identify entities associated with the given entity type options in the input text. Use the format 'entity1: type1; entity2: type2' for your response.\",\n",
    "#         \"Retrieve entities connected to the entity types mentioned in the options from the input text. Format your answer as 'entity1: type1; entity2: type2.'\",\n",
    "#         \"Examine the input text for entities that match the provided entity type options. Respond using the 'entity1: type1; entity2: type2' format.\",\n",
    "#         \"Scour the input text for entities corresponding to the entity types in the options. Present the results in the 'entity1: type1; entity2: type2' format.\",\n",
    "#     ],\n",
    "#     'Does the input text include any entity of type \"{entity_type}\" ? If so, list them all, and answer with format \"entity1; entity2\". If not, Answer \"No\".': [\n",
    "#         \"Check if the input text contains any entities of type '{entity_type}'. If found, list them all in the format 'entity1; entity2.' Otherwise, respond with 'No.'\",\n",
    "#         \"Examine the input text for the presence of entities belonging to the '{entity_type}' type. If any are found, provide a list in the format 'entity1; entity2.' If none are found, answer 'No.'\",\n",
    "#         \"Verify whether the input text contains entities of the type '{entity_type}.' If so, enumerate them in the format 'entity1; entity2.' If not, reply with 'No.'\",\n",
    "#         \"Determine if the input text includes any entities categorized as '{entity_type}.' If it does, present them in the format 'entity1; entity2.' In case of none, respond with 'No.'\",\n",
    "#         \"Assess whether the input text encompasses entities of the '{entity_type}' type. If there are any, list them as 'entity1; entity2.' If there are none, respond with 'No.'\",\n",
    "#         \"Inspect the input text to find out if it contains any entities labeled as '{entity_type}.' If it does, provide a list in the format 'entity1; entity2.' If not, reply with 'No.'\",\n",
    "#         \"Ascertain whether the input text possesses entities falling under the '{entity_type}' category. If it does, compile them in the format 'entity1; entity2.' If it doesn't, answer 'No.'\",\n",
    "#         \"Confirm whether the input text harbors entities of type '{entity_type}.' If it does, list them all as 'entity1; entity2.' If it doesn't, respond with 'No.'\",\n",
    "#         \"Investigate whether the input text features any entities of the '{entity_type}' variety. If any are present, document them in the format 'entity1; entity2.' Otherwise, answer 'No.'\",\n",
    "#         \"Scrutinize the input text to determine the presence of entities belonging to the '{entity_type}' category. If any are found, report them as 'entity1; entity2.' If none are found, respond with 'No.'\"\n",
    "#     ],\n",
    "#     'Please tell me all the entities in the text that belong to the given category \"{entity_type}\". Output format is \"entity1; entity2\". If no such entity can be found, answer \"No\".': [\n",
    "#         \"Identify and provide all entities within the text that fall under the specified category '{entity_type}'. Format the output as 'entity1; entity2.' If no such entity exists, respond with 'No.'\",\n",
    "#         \"Locate and list all entities in the text that are categorized as '{entity_type}'. Use the output format 'entity1; entity2.' If there are no entities of this type, answer 'No.'\",\n",
    "#         \"Find and present all entities within the text that belong to the designated category '{entity_type}'. Ensure the output maintains the format 'entity1; entity2.' If there are none, reply with 'No.'\",\n",
    "#         \"Discover and enumerate all entities in the text that are classified as '{entity_type}'. Format the output as 'entity1; entity2.' In the absence of such entities, respond with 'No.'\",\n",
    "#         \"Search for and compile all entities in the text that pertain to the specified category '{entity_type}'. Use the 'entity1; entity2' format for the output. If none are found, answer 'No.'\",\n",
    "#         \"Identify and gather all entities within the text that fit the given category '{entity_type}'. Present the results in the format 'entity1; entity2.' If there are no entities of this type, reply with 'No.'\",\n",
    "#         \"Detect and list all entities in the text that are associated with the provided category '{entity_type}'. Maintain the output format as 'entity1; entity2.' If none are found, respond with 'No.'\",\n",
    "#         \"Spot and document all entities within the text that are aligned with the specified category '{entity_type}'. Use the format 'entity1; entity2.' In the event of no such entities, answer 'No.'\",\n",
    "#         \"Examine and compile all entities in the text that fall under the given category '{entity_type}'. Keep the output in the 'entity1; entity2' format. If there are no entities of this type, reply with 'No.'\",\n",
    "#         \"Check and provide all entities in the text that match the category '{entity_type}'. Format the output as 'entity1; entity2.' If there are no entities of this type, answer 'No.'\",\n",
    "#     ],\n",
    "#     'Please find all entities in the input text that fit the following entity types: \"person\", \"organization\", \"location\", then answer with the number counted for each type. Output format is \"type1: number1; type2: number2\"': [\n",
    "#         \"Identify all entities within the input text falling under the specified entity types: 'person,' 'organization,' and 'location.' Report the counts for each type in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Detect entities of the 'person,' 'organization,' and 'location' types in the input text. Provide the respective counts for each type as 'type1: number1; type2: number2.'\",\n",
    "#         \"Locate entities categorized as 'person,' 'organization,' and 'location' within the input text. Deliver the counts for each type in the 'type1: number1; type2: number2' format.\",\n",
    "#         \"Search for entities that match the entity types 'person,' 'organization,' and 'location' in the input text. Present the counts for each type in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Gather entities belonging to the specified types 'person,' 'organization,' and 'location' from the input text. Respond with the counts for each type in the 'type1: number1; type2: number2' format.\",\n",
    "#         \"Examine the input text for entities of the 'person,' 'organization,' and 'location' categories. Report the respective counts for each type as 'type1: number1; type2: number2.'\",\n",
    "#         \"Check for entities falling under 'person,' 'organization,' and 'location' types in the input text. Provide the counts for each type in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Assess the input text for entities of types 'person,' 'organization,' and 'location.' Share the counts for each type in the 'type1: number1; type2: number2' format.\",\n",
    "#         \"Inspect the input text to find entities categorized as 'person,' 'organization,' and 'location.' Communicate the counts for each type using the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Identify entities matching 'person,' 'organization,' and 'location' types within the input text. Provide the counts for each type in the 'type1: number1; type2: number2' format.\"\n",
    "#     ],\n",
    "#     'Given entity types as options, please find the number of occurrence for each entity type in the input text. answer with format \"type1: number1; type2: number2\".\\nOptions: \"person\", \"organization\", \"location\".': [\n",
    "#         \"Using the provided entity type options, determine the count of each entity type within the input text. Respond in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Calculate the occurrences of each entity type listed in the options within the input text. Format the answer as 'type1: number1; type2: number2.'\",\n",
    "#         \"Count the instances of each entity type mentioned in the options within the input text. Provide the counts in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Find and tally the number of occurrences for each entity type from the given options in the input text. Answer using the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Examine the input text and enumerate the occurrences of each entity type specified in the options. Report the counts in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Search for and calculate the occurrences of each entity type included in the options within the input text. Present the results in the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Identify and count the number of times each entity type from the provided options appears in the input text. Share the counts using the format 'type1: number1; type2: number2.'\",\n",
    "#         \"Inspect the input text for occurrences of each entity type mentioned in the options. Report the counts in the 'type1: number1; type2: number2' format.\",\n",
    "#         \"Confirm the occurrences of each entity type listed in the options within the input text. Present the counts as 'type1: number1; type2: number2.'\",\n",
    "#         \"Gather the occurrences of each entity type specified in the options within the input text. Format the response as 'type1: number1; type2: number2.'\",\n",
    "#     ],\n",
    "# }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c8821ba4-0a0d-4446-814a-047bd17696f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# def get_augmented_dataset(data):\n",
    "#     instructions, inputs, outputs = [], [], []\n",
    "#     for i, ins in enumerate(data['instruction']):\n",
    "#         if ins not in chatgpt_templates:\n",
    "#             entity_type = re.search(r'location|organization|person', ins).group()\n",
    "#             ins = ins.replace(entity_type, '{entity_type}')\n",
    "#             assert ins in chatgpt_templates\n",
    "#         for ins_template in chatgpt_templates[ins]:\n",
    "#             if '{entity_type}' in ins_template:\n",
    "#                 ins_template = ins_template.replace('{entity_type}', entity_type)\n",
    "#             instructions.append(ins_template)\n",
    "#             inputs.append(data['input'][i])\n",
    "#             outputs.append(data['output'][i])\n",
    "#     return {\n",
    "#         'instruction': data['instruction'] + instructions,\n",
    "#         'input': data['input'] + inputs,\n",
    "#         'output': data['output'] + outputs\n",
    "#     }\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f9ce43c9-8f22-4a79-b5c9-042df702da2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_data_aug_1 = get_augmented_dataset(train_data_1)\n",
    "# train_data_aug_2 = get_augmented_dataset(train_data_2)\n",
    "# train_data_aug_3 = get_augmented_dataset(train_data_3)\n",
    "\n",
    "# test_data_aug_1 = get_augmented_dataset(test_data_1)\n",
    "# test_data_aug_2 = get_augmented_dataset(test_data_2)\n",
    "# test_data_aug_3 = get_augmented_dataset(test_data_3)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a915673a-3f3a-4b68-bdfb-e7e7a49dbb86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_instruct_data = {k: train_data_aug_1[k] + train_data_aug_2[k] + train_data_aug_3[k] for k in train_data_aug_1.keys()} \n",
    "# test_instruct_data = {k: test_data_aug_1[k] + test_data_aug_2[k] + test_data_aug_3[k] for k in test_data_aug_1.keys()} \n",
    "\n",
    "# ner_instruct_dataset = DatasetDict({\n",
    "#     'train': Dataset.from_dict(train_instruct_data),\n",
    "#     'test': Dataset.from_dict(test_instruct_data)\n",
    "# })\n",
    "# ner_instruct_dataset.save_to_disk('fingpt-ner-full-instruct')\n",
    "# ner_instruct_dataset\n",
    "# # train_data_aug_2['instruction'][-30:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbae408e-509b-4b74-80ff-d09e12f24f59",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Classification (CLS) Version for Zero-shot Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "dd3613bf-cd27-47ac-a72d-b5a90d5ecf55",
   "metadata": {},
   "outputs": [],
   "source": [
    "chatgpt_templates = [\n",
    "    \"What is the entity type of '{entity}' in the input sentence.\",\n",
    "    \"With the input text as context, identify the entity type of '{entity}'.\",\n",
    "    \"Using the input sentence as a reference, analyze and specify the entity type of '{entity}'.\",\n",
    "    \"In the context of the input sentence, examine and categorize the entity type of '{entity}'.\",\n",
    "    \"Utilize the input text as context to explore and ascertain the entity type of '{entity}'.\",\n",
    "    \"Leverage the input sentence to evaluate and define the entity type for '{entity}'.\",\n",
    "    \"Considering the input sentence as context, inspect and classify the entity type of '{entity}'.\",\n",
    "    \"With the input sentence as a backdrop, scrutinize and determine the entity type of '{entity}'.\",\n",
    "    \"Interpreting the input sentence as context, specify the entity type for '{entity}'.\",\n",
    "    \"Assessing the input sentence as context, label the entity type of '{entity}'.\",\n",
    "    \"In the input sentence, determine the entity type for '{entity}'.\",\n",
    "    \"Within the input text, identify the entity type of '{entity}'.\",\n",
    "    \"Analyze the input sentence to find the entity type of '{entity}'.\",\n",
    "    \"Check the input sentence for the entity type associated with '{entity}'.\",\n",
    "    \"Explore the input sentence to ascertain the entity type of '{entity}'.\",\n",
    "    \"Examine the input text to classify the entity type of '{entity}'.\",\n",
    "    \"Scrutinize the input sentence to define the entity type of '{entity}'.\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "e8b3d4fe-fa0e-44cb-917e-6677b331c98c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def entities():\n",
    "    options = ['location', 'organization', 'person']\n",
    "    random.shuffle(options)\n",
    "    return \", \".join(options)\n",
    "\n",
    "\n",
    "def get_ner_dataset_4(sentences):\n",
    "    \n",
    "    inputs, outputs, instructions = [], [], []\n",
    "    for sentence in sentences:\n",
    "        is_entity = [tup[1] != 'O' and not tup[1].endswith('MISC') for tup in sentence]\n",
    "        if sum(is_entity) == 0:\n",
    "            continue\n",
    "        tmp_tup_list = []\n",
    "        entity_dict = {}\n",
    "        for i, tup in enumerate(sentence):\n",
    "            if tmp_tup_list and (not is_entity[i] or tmp_tup_list[-1][1] != tup[1] or i + 1 == len(sentence)):\n",
    "                entity = ' '.join([t[0] for t in tmp_tup_list])\n",
    "                assert tmp_tup_list[0][1] == tmp_tup_list[-1][1], tmp_tup_list\n",
    "                entity_type = ent_dict[tmp_tup_list[-1][1].split('-')[-1]]\n",
    "                tmp_tup_list = [] if not is_entity[i] else [tup]\n",
    "                if entity not in entity_dict:\n",
    "                    entity_dict[entity] = entity_type\n",
    "                elif entity_dict[entity] != entity_type:\n",
    "                    entity_dict[entity] = \"\"\n",
    "                else:\n",
    "                    pass\n",
    "            elif is_entity[i]:\n",
    "                tmp_tup_list.append(tup)\n",
    "            else:\n",
    "                pass\n",
    "        # print(entity_dict)\n",
    "        for k, v in entity_dict.items():\n",
    "            # instructions.extend([t.format(entity=entity, options=entities()) for t in chatgpt_templates])\n",
    "            instructions.extend([t.format(entity=entity) + '\\nOptions: ' + entities() for t in chatgpt_templates])\n",
    "            inputs.extend([' '.join([tup[0] for tup in sentence])] * len(chatgpt_templates))\n",
    "            outputs.extend([entity_type] * len(chatgpt_templates))\n",
    "                \n",
    "    print(len(instructions))\n",
    "        \n",
    "    return {\"input\": inputs, \"output\": outputs, \"instruction\": instructions}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "421389f7-0174-44de-8e93-374ced33e949",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13549\n",
      "3502\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004518270492553711,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 13549,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d89aaa59b74485aa4c33dc64ba2637d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/13549 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0043218135833740234,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 3502,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de5e4a165e5e4d44a61247db629bbc1f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/3502 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 13549\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 3502\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data = read_conll_file('./SEC-filings/CONLL-format/data/train/FIN5.txt')\n",
    "test_data = read_conll_file('./SEC-filings/CONLL-format/data/test/FIN3.txt')\n",
    "\n",
    "train_data_cls = get_ner_dataset_4(train_data)\n",
    "test_data_cls = get_ner_dataset_4(test_data)\n",
    "\n",
    "ner_cls_instruct_dataset = DatasetDict({\n",
    "    'train': Dataset.from_dict(train_data_cls),\n",
    "    'test': Dataset.from_dict(test_data_cls)\n",
    "})\n",
    "ner_cls_instruct_dataset.save_to_disk('fingpt-ner-cls-instruct')\n",
    "ner_cls_instruct_dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "e5178c4a-5add-4828-bca5-549b0edb1090",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"What is the entity type of '40 William St' in the input sentence.\\nOptions: person, location, organization\",\n",
       " \"With the input text as context, identify the entity type of '40 William St'.\\nOptions: organization, person, location\",\n",
       " \"Using the input sentence as a reference, analyze and specify the entity type of '40 William St'.\\nOptions: organization, location, person\",\n",
       " \"In the context of the input sentence, examine and categorize the entity type of '40 William St'.\\nOptions: location, organization, person\",\n",
       " \"Utilize the input text as context to explore and ascertain the entity type of '40 William St'.\\nOptions: organization, person, location\",\n",
       " \"Leverage the input sentence to evaluate and define the entity type for '40 William St'.\\nOptions: person, organization, location\",\n",
       " \"Considering the input sentence as context, inspect and classify the entity type of '40 William St'.\\nOptions: organization, location, person\",\n",
       " \"With the input sentence as a backdrop, scrutinize and determine the entity type of '40 William St'.\\nOptions: organization, location, person\",\n",
       " \"Interpreting the input sentence as context, specify the entity type for '40 William St'.\\nOptions: person, location, organization\",\n",
       " \"Assessing the input sentence as context, label the entity type of '40 William St'.\\nOptions: organization, location, person\"]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First 10 training instructions.\n",
    "ner_cls_instruct_dataset['train'][:10]['instruction']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f829608-6cfd-488c-ad8f-053013034e97",
   "metadata": {},
   "source": [
    "# FinRED"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "13db75e7-fc60-4ff6-9806-a27d53c1c19a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relation names offered as answer options, one per line of relations.txt.\n",
    "with open('FinRED/relations.txt') as relations_file:\n",
    "    relations = [line.strip() for line in relations_file]\n",
    "\n",
    "\n",
    "def get_instruction(sent, tuples, with_orig=True, with_cls=False):\n",
    "    \"\"\"Turn one FinRED sentence and its tuples into instruction examples.\n",
    "\n",
    "    Each tuple is a list whose [0] and [1] are the two arguments and [-1] the relation.\n",
    "    with_orig: add two extraction examples whose answer lists every tuple as\n",
    "        'relation: arg0, arg1', joined by '; '.\n",
    "    with_cls: add two examples per tuple asking for the relation between its arguments.\n",
    "    Returns parallel lists (instructions, inputs, outputs).\n",
    "\n",
    "    NOTE(review): answers use the relation string exactly as written in the tuple file,\n",
    "    while the options come from relations.txt; the CLS cell further down normalizes\n",
    "    relation strings ('_' -> ' ', ' / ' and ' or ' -> '/') before matching them against\n",
    "    relations.txt. Confirm which spelling the evaluator expects.\n",
    "    \"\"\"\n",
    "    comma_options = ', '.join(relations)\n",
    "    semicolon_options = '; '.join(relations)\n",
    "    instructions, inputs, outputs = [], [], []\n",
    "\n",
    "    if with_orig:\n",
    "        answer = \"; \".join(f\"{tup[-1]}: {tup[0]}, {tup[1]}\" for tup in tuples)\n",
    "        instructions += [\n",
    "            f\"Given phrases that describe the relationship between two words/phrases as options, extract the word/phrase pair and the corresponding lexical relationship between them from the input text. The output format should be \\\"relation1: word1, word2; relation2: word3, word4\\\". Options: {comma_options}.\",\n",
    "            f\"Given the input sentence, please extract the subject and object containing a certain relation in the sentence according to the following relation types, in the format of \\\"relation1: word1, word2; relation2: word3, word4\\\". Relations include: {semicolon_options}.\",\n",
    "        ]\n",
    "        inputs += [sent, sent]\n",
    "        outputs += [answer, answer]\n",
    "\n",
    "    if with_cls:\n",
    "        for tup in tuples:\n",
    "            first, second, relation = tup[0], tup[1], tup[-1]\n",
    "            instructions += [\n",
    "                f\"Utilize the input text as a context reference, choose the right relationship between {first} and {second} from the options. Options: {comma_options}.\",\n",
    "                f\"What is the relationship between {first} and {second} in the context of the input sentence. Choose an answer from: {semicolon_options}.\",\n",
    "            ]\n",
    "            inputs += [sent, sent]\n",
    "            outputs += [relation, relation]\n",
    "\n",
    "    return instructions, inputs, outputs\n",
    "\n",
    "\n",
    "def get_finred_dataset(sent_file, tup_file, with_orig, with_cls):\n",
    "    \"\"\"Build a FinRED instruction Dataset from aligned sentence and tuple files.\n",
    "\n",
    "    Line k of tup_file holds the tuples of line k of sent_file: tuples are separated\n",
    "    by ' | ' and the fields inside a tuple by ' ; '.\n",
    "    \"\"\"\n",
    "    with open(sent_file) as f:\n",
    "        sentences = [line.strip() for line in f]\n",
    "    with open(tup_file) as f:\n",
    "        raw_tuple_lines = [line.split(' | ') for line in f]\n",
    "\n",
    "    columns = {'input': [], 'output': [], 'instruction': []}\n",
    "    for sent, raw_tuples in zip(sentences, raw_tuple_lines):\n",
    "        tuples = [[field.strip() for field in raw.split(' ; ')] for raw in raw_tuples]\n",
    "        ins, inp, out = get_instruction(sent, tuples, with_orig, with_cls)\n",
    "        columns['instruction'].extend(ins)\n",
    "        columns['input'].extend(inp)\n",
    "        columns['output'].extend(out)\n",
    "\n",
    "    return Dataset.from_dict(columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "fc080217-5907-4050-b8fc-143729b43c04",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0045130252838134766,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 27558,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/27558 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004259586334228516,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 5112,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/5112 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 27558\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 5112\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Extraction (with_orig) plus classification (with_cls) examples -> 'fingpt-finred'.\n",
    "train_dataset, test_dataset = (\n",
    "    get_finred_dataset('FinRED/train.sent', 'FinRED/train.tup', with_orig=True, with_cls=True),\n",
    "    get_finred_dataset('FinRED/test.sent', 'FinRED/test.tup', with_orig=True, with_cls=True),\n",
    ")\n",
    "\n",
    "finred_dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})\n",
    "finred_dataset.save_to_disk('fingpt-finred')\n",
    "finred_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b81ad3e1-6e6f-465c-be23-559b561be89d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.00446009635925293,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 11400,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/11400 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004523754119873047,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 2136,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/2136 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 11400\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 2136\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Relation-extraction examples only (with_orig) -> 'fingpt-finred-re'.\n",
    "train_dataset, test_dataset = (\n",
    "    get_finred_dataset('FinRED/train.sent', 'FinRED/train.tup', with_orig=True, with_cls=False),\n",
    "    get_finred_dataset('FinRED/test.sent', 'FinRED/test.tup', with_orig=True, with_cls=False),\n",
    ")\n",
    "\n",
    "finred_dataset = DatasetDict({'train': train_dataset, 'test': test_dataset})\n",
    "finred_dataset.save_to_disk('fingpt-finred-re')\n",
    "finred_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "700987c5-5561-407d-a87f-ff3ef8a343a0",
   "metadata": {},
   "source": [
    "## Classification (CLS) Version for Zero-shot Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "12b63879-8b4a-4f6e-8e2a-ea769cc5ace4",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('FinRED/relations.txt') as f:\n",
    "    all_relations = [r.strip() for r in f.readlines()]\n",
    "\n",
    "# Paraphrased relation-classification questions; {head}/{tail} are the two tuple\n",
    "# arguments and {options} the shuffled, comma-separated answer options.\n",
    "FINRED_CLS_TEMPLATES = (\n",
    "    \"Utilize the input text as a context reference, choose the right relationship between '{head}' and '{tail}' from the options.\\nOptions: {options}\",\n",
    "    \"Refer to the input text as context and select the correct relationship between '{head}' and '{tail}' from the available options.\\nOptions: {options}\",\n",
    "    \"Take context from the input text and decide on the accurate relationship between '{head}' and '{tail}' from the options provided.\\nOptions: {options}\",\n",
    "    \"What is the relationship between '{head}' and '{tail}' in the context of the input sentence.\\nOptions: {options}\",\n",
    "    \"In the context of the input sentence, determine the relationship between '{head}' and '{tail}'.\\nOptions: {options}\",\n",
    "    \"Analyze the relationship between '{head}' and '{tail}' within the context of the input sentence.\\nOptions: {options}\",\n",
    ")\n",
    "\n",
    "# NOTE: get_instruction / get_finred_dataset below shadow the FinRED functions defined in an\n",
    "# earlier cell (which take with_orig/with_cls); re-run that cell before re-running its callers.\n",
    "\n",
    "\n",
    "def get_instruction(sent, tuples, rng=random):\n",
    "    \"\"\"Build six multiple-choice relation-classification examples per tuple.\n",
    "\n",
    "    The gold relation is tup[-1], normalized ('_' -> ' ', ' / ' and ' or ' -> '/') so it\n",
    "    can be found in all_relations. The options are the gold relation plus three random\n",
    "    distractors, reshuffled before each template.\n",
    "    rng: object providing shuffle() (the random module or a random.Random instance).\n",
    "    Returns parallel lists (instructions, inputs, outputs).\n",
    "    \"\"\"\n",
    "    instructions, inputs, outputs = [], [], []\n",
    "    for tup in tuples:\n",
    "        output = tup[-1].replace('_', ' ').replace(' / ', '/').replace(' or ', '/')\n",
    "        distractors = all_relations.copy()\n",
    "        distractors.remove(output)\n",
    "        rng.shuffle(distractors)\n",
    "        options = distractors[:3] + [output]\n",
    "        # One shuffle before every template: the same RNG call sequence as the\n",
    "        # previous hand-unrolled version, so seeded output is unchanged.\n",
    "        for template in FINRED_CLS_TEMPLATES:\n",
    "            rng.shuffle(options)\n",
    "            instructions.append(template.format(head=tup[0], tail=tup[1], options=', '.join(options)))\n",
    "        inputs.extend([sent] * len(FINRED_CLS_TEMPLATES))\n",
    "        outputs.extend([output] * len(FINRED_CLS_TEMPLATES))\n",
    "\n",
    "    return instructions, inputs, outputs\n",
    "\n",
    "\n",
    "def get_finred_dataset(sent_file, tup_file, seed=0):\n",
    "    \"\"\"Build the FinRED relation-classification Dataset from aligned sentence/tuple files.\n",
    "\n",
    "    Line k of tup_file holds the ' | '-separated tuples of line k of sent_file; tuple\n",
    "    fields are separated by ' ; '. seed makes option sampling and ordering reproducible.\n",
    "    \"\"\"\n",
    "    # A private generator gives the same results as the former random.seed(0) call\n",
    "    # without resetting the notebook-wide global RNG as a side effect.\n",
    "    rng = random.Random(seed)\n",
    "    instructions, inputs, outputs = [], [], []\n",
    "\n",
    "    with open(sent_file) as f:\n",
    "        sentences = [s.strip() for s in f.readlines()]\n",
    "    with open(tup_file) as f:\n",
    "        tuples_list = [s.split(' | ') for s in f.readlines()]\n",
    "\n",
    "    for sent, tuples in zip(sentences, tuples_list):\n",
    "        tuples = [[e.strip() for e in tup.split(' ; ')] for tup in tuples]\n",
    "        ins, i, o = get_instruction(sent, tuples, rng)\n",
    "        instructions.extend(ins)\n",
    "        inputs.extend(i)\n",
    "        outputs.extend(o)\n",
    "\n",
    "    return Dataset.from_dict({\n",
    "        'input': inputs,\n",
    "        'output': outputs,\n",
    "        'instruction': instructions\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "22eb8f60-3060-4e83-a785-d3a66c7f60f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004596710205078125,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 48474,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/48474 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004204511642456055,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 8928,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/8928 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 48474\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 8928\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Zero-shot relation-classification splits -> 'fingpt-finred-cls-instruct'.\n",
    "train_dataset_instruct, test_dataset_instruct = (\n",
    "    get_finred_dataset('FinRED/train.sent', 'FinRED/train.tup'),\n",
    "    get_finred_dataset('FinRED/test.sent', 'FinRED/test.tup'),\n",
    ")\n",
    "\n",
    "finred_dataset_instruct = DatasetDict({'train': train_dataset_instruct, 'test': test_dataset_instruct})\n",
    "finred_dataset_instruct.save_to_disk('fingpt-finred-cls-instruct')\n",
    "finred_dataset_instruct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "ac71c924-4372-4b72-95d8-92a063cc9fdb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"Utilize the input text as a context reference, choose the right relationship between 'Apple Inc' and 'Steve Jobs' from the options.\\nOptions: industry, founded by, owner of, currency\",\n",
       " \"Refer to the input text as context and select the correct relationship between 'Apple Inc' and 'Steve Jobs' from the available options.\\nOptions: industry, currency, owner of, founded by\",\n",
       " \"Take context from the input text and decide on the accurate relationship between 'Apple Inc' and 'Steve Jobs' from the options provided.\\nOptions: industry, currency, owner of, founded by\",\n",
       " \"What is the relationship between 'Apple Inc' and 'Steve Jobs' in the context of the input sentence.\\nOptions: currency, founded by, owner of, industry\",\n",
       " \"In the context of the input sentence, determine the relationship between 'Apple Inc' and 'Steve Jobs'.\\nOptions: industry, founded by, owner of, currency\",\n",
       " \"Analyze the relationship between 'Apple Inc' and 'Steve Jobs' within the context of the input sentence.\\nOptions: currency, founded by, owner of, industry\",\n",
       " \"Utilize the input text as a context reference, choose the right relationship between 'Apple Inc' and 'Steve Jobs' from the options.\\nOptions: chief executive officer, headquarters location, subsidiary, industry\",\n",
       " \"Refer to the input text as context and select the correct relationship between 'Apple Inc' and 'Steve Jobs' from the available options.\\nOptions: subsidiary, chief executive officer, headquarters location, industry\",\n",
       " \"Take context from the input text and decide on the accurate relationship between 'Apple Inc' and 'Steve Jobs' from the options provided.\\nOptions: chief executive officer, industry, subsidiary, headquarters location\",\n",
       " \"What is the relationship between 'Apple Inc' and 'Steve Jobs' in the context of the input sentence.\\nOptions: subsidiary, headquarters location, chief executive officer, industry\"]"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First 10 training instructions.\n",
    "finred_dataset_instruct['train'][:10]['instruction']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb84d245-daef-483d-bf77-8bde1dcc4481",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# ConvFinQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "964e11a4-6f54-42de-bde4-ed8aac19a010",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_confinqa_dataset(json_file):\n",
    "    \"\"\"Build a ConvFinQA instruction Dataset from a turn-level JSON file.\n",
    "\n",
    "    Each sample yields one example. The input is the report text and table, the earlier\n",
    "    question/answer turns, and the current question (index turn_ind). The output is\n",
    "    str(exe_ans_list[turn_ind]). All examples share one instruction string.\n",
    "    \"\"\"\n",
    "    # Adjacent string literals are concatenated verbatim, so each piece needs its own\n",
    "    # trailing separator; the first one previously lacked it ('carefully.Based on ...').\n",
    "    instruction = (\n",
    "        'Read the following texts and table with financial data from an S&P 500 earnings report carefully. '\n",
    "        'Based on the question-answer history (if provided), answer the last question. '\n",
    "        'The answer may require mathematical calculation based on the data provided.\\n'\n",
    "    )\n",
    "\n",
    "    # Use a context manager so the file handle is closed promptly.\n",
    "    with open(json_file) as f:\n",
    "        samples = json.load(f)\n",
    "\n",
    "    instructions, inputs, outputs = [], [], []\n",
    "    for sample in samples:\n",
    "        annos = sample['annotation']\n",
    "        pre_text, post_text = annos['amt_pre_text'], annos['amt_post_text']\n",
    "        # Fill empty table cells with '-' so every cell has visible content.\n",
    "        table = annos['amt_table'].replace('<td></td>', '<td>-</td>')\n",
    "        context = f'{pre_text} {table} {post_text}\\n'\n",
    "        questions, answers, turn_ind = annos['dialogue_break'], annos['exe_ans_list'], annos['turn_ind']\n",
    "        # Conversation history: every turn before the one being asked.\n",
    "        for i in range(turn_ind):\n",
    "            context += f'Question: {questions[i]}\\n'\n",
    "            context += f'Answer: {answers[i]}\\n'\n",
    "        context += f'Question: {questions[turn_ind]}\\n'\n",
    "        outputs.append(str(answers[turn_ind]))\n",
    "        instructions.append(instruction)\n",
    "        inputs.append(context)\n",
    "\n",
    "    return Dataset.from_dict({\n",
    "        'input': inputs,\n",
    "        'output': outputs,\n",
    "        'instruction': instructions\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "e817bc45-9825-4af6-bd34-94a6307c15fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004643678665161133,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 11104,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4bebad1197d14912ba83461801e19b9a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/11104 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004342794418334961,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 1490,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eeba81a2788b4b9b87f47671a9d6a579",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1490 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 11104\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 1490\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn-level ConvFinQA files; the dev split doubles as our test split.\n",
    "train_file, test_file = 'ConvFinQA/data/train_turn.json', 'ConvFinQA/data/dev_turn.json'\n",
    "train_dataset, test_dataset = (get_confinqa_dataset(path) for path in (train_file, test_file))\n",
    "\n",
    "convfinqa_dataset = DatasetDict(\n",
    "    {'train': train_dataset, 'test': test_dataset}\n",
    ")\n",
    "convfinqa_dataset.save_to_disk('fingpt-convfinqa')\n",
    "convfinqa_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "babd6ee9-3d28-4ca6-bdc9-4d13cb76c7be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': 'in a new business model such as the retail segment is inherently risky , particularly in light of the significant investment involved , the current economic climate , and the fixed nature of a substantial portion of the retail segment\\'s operating expenses . results for this segment are dependent upon a number of risks and uncertainties , some of which are discussed below under the heading \"factors that may affect future results and financial condition.\" backlog in the company\\'s experience , the actual amount of product backlog at any particular time is not a meaningful indication of its future business prospects . in particular , backlog often increases in anticipation of or immediately following new product introductions because of over- ordering by dealers anticipating shortages . backlog often is reduced once dealers and customers believe they can obtain sufficient supply . because of the foregoing , backlog cannot be considered a reliable indicator of the company\\'s ability to achieve any particular level of revenue or financial performance . further information regarding the company\\'s backlog may be found below under the heading \"factors that may affect future results and financial condition.\" gross margin gross margin for the three fiscal years ended september 28 , 2002 are as follows ( in millions , except gross margin percentages ) : gross margin increased to 28% ( 28 % ) of net sales in 2002 from 23% ( 23 % ) in 2001 . as discussed below , gross margin in 2001 was unusually low resulting from negative gross margin of 2% ( 2 % ) experienced in the first quarter of 2001 . as a percentage of net sales , the company\\'s quarterly gross margins declined during fiscal 2002 from 31% ( 31 % ) in the first quarter down to 26% ( 26 % ) in the fourth quarter . 
this decline resulted from several factors including a rise in component costs as the year progressed and aggressive pricing by the company across its products lines instituted as a result of continued pricing pressures in the personal computer industry . the company anticipates that its gross margin and the gross margin of the overall personal computer industry will remain under pressure throughout fiscal 2003 in light of weak economic conditions , flat demand for personal computers in general , and the resulting pressure on prices . the foregoing statements regarding anticipated gross margin in 2003 and the general demand for personal computers during 2003 are forward- looking . gross margin could differ from anticipated levels because of several factors , including certain of those set forth below in the subsection entitled \"factors that may affect future results and financial condition.\" there can be no assurance that current gross margins will be maintained , targeted gross margin levels will be achieved , or current margins on existing individual products will be maintained . in general , gross margins and margins on individual products will remain under significant downward pressure due to a variety of factors , including continued industry wide global pricing pressures , increased competition , compressed product life cycles , potential increases in the cost and availability of raw material and outside manufacturing services , and potential changes to the company\\'s product mix , including higher unit sales of consumer products with lower average selling prices and lower gross margins . in response to these downward pressures , the company expects it will continue to take pricing actions with respect to its products . gross margins could also be affected by the company\\'s ability to effectively manage quality problems and warranty costs and to stimulate demand for certain of its products . 
the company\\'s operating strategy and pricing take into account anticipated changes in foreign currency exchange rates over time ; however , the company\\'s results of operations can be significantly affected in the short-term by fluctuations in exchange rates . the company orders components for its products and builds inventory in advance of product shipments . because the company\\'s markets are volatile and subject to rapid technology and price changes , there is a risk the company will forecast incorrectly and produce or order from third parties excess or insufficient inventories of particular products or components . the company\\'s operating results and financial condition have been in the past and may in the future be materially adversely affected by the company\\'s ability to manage its inventory levels and outstanding purchase commitments and to respond to short-term shifts in customer demand patterns . gross margin declined to 23% ( 23 % ) of net sales in 2001 from 27% ( 27 % ) in 2000 . this decline resulted primarily from gross margin of negative 2% ( 2 % ) experienced during the first quarter of 2001 compared to 26% ( 26 % ) gross margin for the same quarter in 2000 . in addition to lower than normal net . <table class=\\'wikitable\\'><tr><td>1</td><td>-</td><td>2002</td><td>2001</td><td>2000</td></tr><tr><td>2</td><td>net sales</td><td>$ 5742</td><td>$ 5363</td><td>$ 7983</td></tr><tr><td>3</td><td>cost of sales</td><td>4139</td><td>4128</td><td>5817</td></tr><tr><td>4</td><td>gross margin</td><td>$ 1603</td><td>$ 1235</td><td>$ 2166</td></tr><tr><td>5</td><td>gross margin percentage</td><td>28% ( 28 % )</td><td>23% ( 23 % )</td><td>27% ( 27 % )</td></tr></table> .\\nQuestion: what was the total of net sales in 2001?\\nAnswer: 5363.0\\nQuestion: and what was that in 2000?\\n',\n",
       " 'output': '7983.0',\n",
       " 'instruction': 'Read the following texts and table with financial data from an S&P 500 earnings report carefully.Based on the question-answer history (if provided), answer the last question. The answer may require mathematical calculation based on the data provided.\\n'}"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check a single ConvFinQA training example.\n",
    "convfinqa_dataset[\"train\"][9]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44f2a7ea-b832-462f-8103-eca6ee4046bb",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# FinEval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "7f42f8eb-4174-46e7-b2fe-7c8a6487f084",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004954338073730469,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 1056,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3784d06ad3f94b04a2e185d139aa9f33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/1056 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0040547847747802734,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 265,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ae0d31a357e74c6d825773a7c8f81cb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/265 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 1056\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['input', 'output', 'instruction'],\n",
       "        num_rows: 265\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# FinEval: pool the val and dev CSVs of every subject, then make our own 80/20 split.\n",
    "# glob() order depends on the filesystem, so sort it to keep the seeded split reproducible.\n",
    "csv_list = sorted(glob('FinEval/val/*.csv')) + sorted(glob('FinEval/dev/*.csv'))\n",
    "with open('FinEval/subject_mapping.json', encoding='utf-8') as f:\n",
    "    subject_mapping = json.load(f)\n",
    "\n",
    "instructions, inputs, outputs = [], [], []\n",
    "\n",
    "for csv_file in csv_list:\n",
    "    # Drop the trailing '_<split>.csv' (e.g. '_val.csv') to get the subject key;\n",
    "    # basename() also handles Windows path separators.\n",
    "    subject_key = os.path.basename(csv_file).rsplit('_', 1)[0]\n",
    "    # Element 1 of the mapping entry is the subject name inserted into the prompt.\n",
    "    subject = subject_mapping[subject_key][1]\n",
    "    df = pd.read_csv(csv_file)\n",
    "    for _, row in df.iterrows():\n",
    "        # Prompt (zh): \"The following are single-choice questions from Chinese {subject}\n",
    "        # exams; please select the correct answer.\"\n",
    "        instructions.append(f'以下是中国关于{subject}考试的单项选择题，请选出其中的正确答案。')\n",
    "        inputs.append(f\"{row['question']}\\nA. {row['A']}\\nB. {row['B']}\\nC. {row['C']}\\nD. {row['D']}\\n\")\n",
    "        # Target is the answer letter followed by that option's text, e.g. 'D. <option D>'.\n",
    "        outputs.append(f\"{row['answer']}. {row[row['answer']]}\")\n",
    "\n",
    "fineval_dataset = Dataset.from_dict({\n",
    "    'input': inputs,\n",
    "    'output': outputs,\n",
    "    'instruction': instructions\n",
    "})\n",
    "fineval_dataset = fineval_dataset.train_test_split(test_size=0.2, seed=42)\n",
    "fineval_dataset.save_to_disk('fingpt-fineval')\n",
    "fineval_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "cd766190-4cdc-4f4d-93d1-d2bfd0627093",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input': '研究社会资本再生产的出发点是____。\\nA. 货币资本\\nB. 生产资本\\nC. 流通资本\\nD. 社会总产品\\n',\n",
       " 'output': 'D. 社会总产品',\n",
       " 'instruction': '以下是中国关于政治经济学考试的单项选择题，请选出其中的正确答案。'}"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check a single pooled FinEval training example.\n",
    "fineval_dataset[\"train\"][3]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c87615b-95e9-4dfd-b2a4-fa5ed900c2ba",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "# FiQA QA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "a410166c-265f-4843-afba-35feb14b6063",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FiQA QA training files (tab-separated): answer documents, questions,\n",
    "# and the question -> document pairing table.\n",
    "docs = pd.read_csv(\"FiQA_train/FiQA_train_doc_final.tsv\", sep=\"\\t\")\n",
    "questions = pd.read_csv(\"FiQA_train/FiQA_train_question_final.tsv\", sep=\"\\t\")\n",
    "qa_pairs = pd.read_csv(\"FiQA_train/FiQA_train_question_doc_final.tsv\", sep=\"\\t\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "96230dce-bac3-43f1-ba19-48eeeb03a347",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.005600690841674805,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Saving the dataset (0/1 shards)",
       "rate": null,
       "total": 17110,
       "unit": " examples",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b1cce16e3de4d49ae7ac66647e093a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Saving the dataset (0/1 shards):   0%|          | 0/17110 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input', 'output', 'instruction'],\n",
       "    num_rows: 17110\n",
       "})"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Id -> text lookup tables for answer documents and questions.\n",
    "doc_dict = dict(zip(docs['docid'], docs['doc']))\n",
    "question_dict = dict(zip(questions['qid'], questions['question']))\n",
    "\n",
    "instruction_templates = [\n",
    "    \"Utilize your financial knowledge, give your answer or opinion to the input question or subject . Answer format is not limited.\",\n",
    "    \"Offer your insights or judgment on the input financial query or topic using your financial expertise. Reply as normal question answering\",\n",
    "    \"Based on your financial expertise, provide your response or viewpoint on the given financial question or topic. The response format is open.\",\n",
    "    \"Share your insights or perspective on the financial matter presented in the input.\",\n",
    "    \"Offer your thoughts or opinion on the input financial query or topic using your financial background.\"\n",
    "]\n",
    "\n",
    "inputs, outputs, instructions = [], [], []\n",
    "# One example per (question, answer document) pair. The instruction wording rotates\n",
    "# through the templates by row index; using len() keeps the rotation correct if\n",
    "# templates are added or removed.\n",
    "for i, qid, docid in zip(qa_pairs.index, qa_pairs['qid'], qa_pairs['docid']):\n",
    "    # str() because some cells may not be strings (presumably NaN from empty TSV fields; TODO confirm).\n",
    "    inputs.append(str(question_dict[qid]))\n",
    "    outputs.append(str(doc_dict[docid]))\n",
    "    instructions.append(instruction_templates[i % len(instruction_templates)])\n",
    "\n",
    "fiqa_qa_dataset = Dataset.from_dict({\n",
    "    'input': inputs,\n",
    "    'output': outputs,\n",
    "    'instruction': instructions\n",
    "})\n",
    "fiqa_qa_dataset.save_to_disk('fingpt-fiqa_qa')\n",
    "fiqa_qa_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a8558d-97a9-4b90-8835-cbfa32e28125",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07cfc13-4b1e-4672-ab76-64f90c9d03e1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
